20a1f2
@@ -18,6 +18,7 @@
 
 package org.wildfly.security.auth.realm.ldap;
 
+import org.jboss.modules.Module;
 import org.wildfly.security.SecurityFactory;
 import org.wildfly.security.auth.callback.CredentialCallback;
 import org.wildfly.security.auth.client.AuthenticationConfiguration;
@@ -25,6 +26,8 @@
import org.wildfly.security.auth.client.AuthenticationContext;
 import org.wildfly.security.auth.client.AuthenticationContextConfigurationClient;
 import org.wildfly.security.credential.PasswordCredential;
 import org.wildfly.security.credential.source.CredentialSource;
+import org.wildfly.security.manager.WildFlySecurityManager;
+import org.wildfly.security.manager.action.GetModuleClassLoaderAction;
 import org.wildfly.security.password.interfaces.ClearPassword;
 
 import static java.security.AccessController.doPrivileged;
@@ -75,6 +78,8 @@
public class SimpleDirContextFactoryBuilder {
     private Properties connectionProperties;
     private int connectTimeout = DEFAULT_CONNECT_TIMEOUT;
     private int readTimeout = DEFAULT_READ_TIMEOUT;
+    private Module targetModule; // set via setModule(); build() consults it to choose targetClassLoader
+    private ClassLoader targetClassLoader; // assigned in build(); installed as the thread context class loader while contexts are created and closed
 
     private static final AuthenticationContextConfigurationClient authClient = doPrivileged(AuthenticationContextConfigurationClient.ACTION);
 
@@ -246,6 +251,19 @@
public class SimpleDirContextFactoryBuilder {
         return this;
     }
 
+    /**
+     * Sets the module whose class loader is installed as the thread context class loader while
+     * directory contexts are created and closed.
+     *
+     * @param module the module to use, or {@code null} (the default) to use the class loader of this builder's class
+     * @return this builder
+     */
+    public SimpleDirContextFactoryBuilder setModule(final Module module) {
+        assertNotBuilt();
+        targetModule = module;
+        return this;
+    }
+
     /**
      * Build this context factory.
      *
@@ -257,7 +275,19 @@
public class SimpleDirContextFactoryBuilder {
         if (providerUrl == null) {
             throw log.noProviderUrlSet();
         }
-
+        // Resolve the class loader that createDirContext()/returnContext() install as the thread context
+        // class loader. It must be assigned on every path: a null value would make them clear the TCCL.
+        // When a security manager is checking, the loader is obtained through a privileged action;
+        // without a module, the class loader of this builder's class is used.
+        if (targetModule != null) {
+            targetClassLoader = WildFlySecurityManager.isChecking()
+                    ? doPrivileged(new GetModuleClassLoaderAction(targetModule))
+                    : targetModule.getClassLoader();
+        } else {
+            targetClassLoader = WildFlySecurityManager.isChecking()
+                    ? WildFlySecurityManager.getClassLoaderPrivileged(getClass())
+                    : getClass().getClassLoader();
+        }
         built = true;
         return new SimpleDirContextFactory();
     }
@@ -369,72 +399,92 @@
public class SimpleDirContextFactoryBuilder {
         }
 
         private DirContext createDirContext(String securityPrincipal, char[] securityCredential, ReferralMode mode, SocketFactory socketFactory) throws NamingException {
-            Hashtable<String, Object> env = new Hashtable<>();
-
-            env.put(InitialDirContext.INITIAL_CONTEXT_FACTORY, initialContextFactory);
-            env.put(InitialDirContext.PROVIDER_URL, providerUrl);
-            env.put(InitialDirContext.SECURITY_AUTHENTICATION, securityAuthentication);
-            if (securityPrincipal != null) env.put(InitialDirContext.SECURITY_PRINCIPAL, securityPrincipal);
-            if (securityCredential != null) env.put(InitialDirContext.SECURITY_CREDENTIALS, securityCredential);
-            env.put(InitialDirContext.REFERRAL, mode == null ? ReferralMode.IGNORE.getValue() : mode.getValue());
-            if (socketFactory != null) env.put(SOCKET_FACTORY, ThreadLocalSSLSocketFactory.class.getName());
-            env.put(CONNECT_TIMEOUT, Integer.toString(connectTimeout));
-            env.put(READ_TIMEOUT, Integer.toString(readTimeout));
-
-            // set any additional connection property
-            if (connectionProperties != null) {
-                for (Object key : connectionProperties.keySet()) {
-                    Object value = connectionProperties.get(key.toString());
-
-                    if (value != null) {
-                        env.put(key.toString(), value);
+            final ClassLoader oldClassLoader = setClassLoaderTo(targetClassLoader); // targetClassLoader is the thread context class loader for the rest of this method
+            try{
+                Hashtable<String, Object> env = new Hashtable<>();
+
+                env.put(InitialDirContext.INITIAL_CONTEXT_FACTORY, initialContextFactory);
+                env.put(InitialDirContext.PROVIDER_URL, providerUrl);
+                env.put(InitialDirContext.SECURITY_AUTHENTICATION, securityAuthentication);
+                if (securityPrincipal != null) env.put(InitialDirContext.SECURITY_PRINCIPAL, securityPrincipal);
+                if (securityCredential != null) env.put(InitialDirContext.SECURITY_CREDENTIALS, securityCredential);
+                env.put(InitialDirContext.REFERRAL, mode == null ? ReferralMode.IGNORE.getValue() : mode.getValue());
+                if (socketFactory != null) env.put(SOCKET_FACTORY, ThreadLocalSSLSocketFactory.class.getName());
+                env.put(CONNECT_TIMEOUT, Integer.toString(connectTimeout));
+                env.put(READ_TIMEOUT, Integer.toString(readTimeout));
+
+                // set any additional connection property
+                if (connectionProperties != null) {
+                    for (Object key : connectionProperties.keySet()) {
+                        Object value = connectionProperties.get(key.toString());
+
+                        if (value != null) {
+                            env.put(key.toString(), value);
+                        }
                     }
                 }
-            }
 
-            if (log.isDebugEnabled()) {
-                log.debugf("Creating [" + InitialDirContext.class + "] with environment:");
-                env.forEach((key, value) -> {
-                    if (value instanceof Object[]) {
-                        log.debugf("    Property [%s] with values %s", key, Arrays.deepToString((Object[]) value));
-                    } else {
-                        log.debugf("    Property [%s] with value [%s]", key, value);
-                    }
-                });
-            }
+                if (log.isDebugEnabled()) {
+                    log.debugf("Creating [" + InitialDirContext.class + "] with environment:");
+                    env.forEach((key, value) -> {
+                        if (value instanceof Object[]) {
+                            log.debugf("    Property [%s] with values %s", key, Arrays.deepToString((Object[]) value));
+                        } else {
+                            log.debugf("    Property [%s] with value [%s]", key, value);
+                        }
+                    });
+                }
 
-            InitialLdapContext initialContext;
+                InitialLdapContext initialContext;
 
-            try {
-                if (socketFactory != null) ThreadLocalSSLSocketFactory.set(socketFactory);
-                initialContext = new InitialLdapContext(env, null);
-            } catch (NamingException ne) {
-                log.debugf(ne, "Could not create [%s]. Failed to connect to LDAP server.", InitialLdapContext.class);
-                throw ne;
-            } finally {
-                if (socketFactory != null) ThreadLocalSSLSocketFactory.unset();
-            }
+                try {
+                    if (socketFactory != null) ThreadLocalSSLSocketFactory.set(socketFactory);
+                    initialContext = new InitialLdapContext(env, null);
+                } catch (NamingException ne) {
+                    log.debugf(ne, "Could not create [%s]. Failed to connect to LDAP server.", InitialLdapContext.class);
+                    throw ne;
+                } finally {
+                    if (socketFactory != null) ThreadLocalSSLSocketFactory.unset();
+                }
 
-            log.debugf("[%s] successfully created. Connection established to LDAP server.", initialContext);
+                log.debugf("[%s] successfully created. Connection established to LDAP server.", initialContext);
 
-            return new DelegatingLdapContext(initialContext, this::returnContext, socketFactory);
+                return new DelegatingLdapContext(initialContext, this::returnContext, socketFactory);
+            } finally{ // runs on success and on NamingException alike
+                setClassLoaderTo(oldClassLoader); // NOTE(review): restores setClassLoaderTo()'s return value — confirm it is the previous thread context class loader
+            }
         }
 
         @Override
         public void returnContext(DirContext context) {
+            // Null is ignored; only InitialDirContext instances are closed, with targetClassLoader as the thread context class loader.
             if (context == null) {
                 return;
             }
 
             if (context instanceof InitialDirContext) {
+                final ClassLoader oldClassLoader = setClassLoaderTo(targetClassLoader);
                 try {
                     context.close();
                     log.debugf("Context [%s] was closed. Connection closed or just returned to the pool.", context);
                 } catch (NamingException ignored) {
+                } finally { // NOTE(review): a NamingException from close() is silently ignored above
+                    setClassLoaderTo(oldClassLoader);
                 }
             }
         }
 
+        /** Installs {@code targetClassLoader} as the thread context class loader and returns the previous one for restoring. */
+        private ClassLoader setClassLoaderTo(final ClassLoader targetClassLoader) {
+            if (WildFlySecurityManager.isChecking()) {
+                final ClassLoader previous = WildFlySecurityManager.getCurrentContextClassLoaderPrivileged();
+                WildFlySecurityManager.setCurrentContextClassLoaderPrivileged(targetClassLoader);
+                return previous;
+            }
+            final ClassLoader previous = Thread.currentThread().getContextClassLoader(); // the TCCL itself, not getClass().getClassLoader()
+            Thread.currentThread().setContextClassLoader(targetClassLoader);
+            return previous;
+        }
     }
 
 }
